package test;

import java.io.BufferedReader;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;

import InputRead.ReadDataFile;
import Utilities.MathUtil;

import dtelements.Attribute;
import dtelements.AttributeSet;
import dtelements.DecisionTree;
import dtelements.Example;
import dtelements.Node;

public class HW2 {

	/** Number of attributes read from the training file; set by run_main and used as the tree-depth limit. */
	private int noOfAttributes;

	/** When true, run_main prints intermediate trees and tune/test accuracies. */
	static final boolean debugMode = false;

	/**
	 * Program entry point.
	 *
	 * @param args command-line arguments (ignored)
	 */
	public static void main(String[] args) {

		new HW2();
	}

	/**
	 * Creates the driver and immediately runs {@link #test()}.
	 * The full decision-tree pipeline is available through {@link #run_main()}.
	 */
	public HW2()
	{
		test();
	}

	/**
	 * Ad-hoc experiment. Reads files/train.txt to obtain the class labels, then scans the
	 * comma-separated file files/rtrain.txt. Among rows whose column 1 equals "1st" and whose
	 * column 3 equals "female", it counts how many have column 0 equal to classLabels[0]
	 * (negative) and how many do not (positive), and prints both counts. Each qualifying row is
	 * counted exactly once. Rows with fewer than four columns (e.g. blank lines) are skipped.
	 * I/O errors are reported with printStackTrace.
	 */
	public void test()
	{
		ArrayList<Attribute> allAttributes = new ArrayList<Attribute>();
		String[] classLabels = new String[2]; 	// classLabels[0] : label of the negative example
												// classLabels[1] : label of the positive example
		ArrayList<Example> allTrainExamples = new ArrayList<Example>();

		// Only classLabels is used below; readFile fills all three out-parameters.
		ReadDataFile.readFile(allAttributes, allTrainExamples, classLabels, "files/train.txt");

		BufferedReader read = null;
		try {
			read = new BufferedReader(new FileReader("files/rtrain.txt"));
			int noOfNegatives = 0;
			int noOfPositives = 0;

			String line;
			while ((line = read.readLine()) != null)
			{
				String[] values = line.split(",");
				if (values.length < 4)
				{
					continue; // not enough columns to evaluate the filter
				}

				if (values[1].equals("1st") && values[3].equals("female"))
				{
					if (values[0].equals(classLabels[0]))
						noOfNegatives++;
					else
						noOfPositives++;
				}
			}

			System.out.println(classLabels[0] + ": " + noOfNegatives);
			System.out.println(classLabels[1] + ": " + noOfPositives);
		} catch (IOException e) {
			// Also covers the FileNotFoundException thrown by the FileReader constructor.
			e.printStackTrace();
		} finally {
			// Release the file handle even when reading fails part-way.
			if (read != null)
			{
				try {
					read.close();
				} catch (IOException e) {
					e.printStackTrace();
				}
			}
		}
	}

	/**
	 * Runs the decision-tree pipeline with the default settings: mode 4 on
	 * files/train.txt, files/tune.txt and files/test.txt.
	 */
	public void run_main()
	{
		run_main(4, "files/train.txt", "files/tune.txt", "files/test.txt");
	}

	/**
	 * Reads the train/tune/test files, grows a decision tree on the training set and reports
	 * according to modeFlag:
	 * <ul>
	 * <li>0 - print the mutual information of every attribute on the training set;</li>
	 * <li>1 - print the unpruned tree;</li>
	 * <li>2 - print the unpruned tree's predictions on the test set;</li>
	 * <li>3 - prune using the tune set and print the pruned tree;</li>
	 * <li>4 - prune using the tune set and print the pruned tree's predictions on the test set.</li>
	 * </ul>
	 * Modes 0-4 finish with System.exit(0); any other value only reads the files.
	 *
	 * @param modeFlag      output mode (see above)
	 * @param trainFileName training data; the attributes and class labels are taken from this file
	 * @param tuneFileName  data whose accuracy guides pruning (modes 3 and 4)
	 * @param testFileName  data used for predictions and debug accuracies
	 */
	public void run_main(int modeFlag, String trainFileName, String tuneFileName, String testFileName)
	{
		ArrayList<Attribute> allAttributes = new ArrayList<Attribute>();
		String[] classLabels = new String[2]; 	// classLabels[0] : label of the negative example
												// classLabels[1] : label of the positive example

		ArrayList<Example> allTrainExamples = new ArrayList<Example>();
		ArrayList<Example> allTuneExamples = new ArrayList<Example>();
		ArrayList<Example> allTestExamples = new ArrayList<Example>();

		// The attribute lists and labels read from the tune/test files are discarded;
		// the training file's are used for everything.
		ReadDataFile.readFile(allAttributes, allTrainExamples, classLabels, trainFileName);
		ReadDataFile.readFile(new ArrayList<Attribute>(), allTuneExamples, new String[2], tuneFileName);
		ReadDataFile.readFile(new ArrayList<Attribute>(), allTestExamples, new String[2], testFileName);

		this.noOfAttributes = allAttributes.size();

		if (modeFlag == 0)
		{
			for (Attribute attribute : allAttributes)
			{
				System.out.println(attribute.getName() + " " + MutualInformation(attribute, allTrainExamples));
			}
			System.exit(0);
		}

		if (modeFlag == 1 || modeFlag == 2 || modeFlag == 3 || modeFlag == 4)
		{
			// The root works on its own copy of the training examples.
			Node root = createNode(allAttributes, new ArrayList<Example>(allTrainExamples));
			BuildDecisionTree(root);
			DecisionTree dt = new DecisionTree(root, classLabels);

			if (modeFlag == 1)
			{
				dt.Report(dt.getRoot());
				if (debugMode)
				{
					System.out.println("ACCURACY on Test Set: " + Accuracy(dt, allTestExamples));
				}
				System.exit(0);
			}

			if (modeFlag == 2)
			{
				ArrayList<String> predictions = Predictions(dt, allTestExamples, classLabels);
				PrintArrayList(predictions);
				if (debugMode)
				{
					System.out.println("ACCURACY on Test Set: " + Accuracy(dt, allTestExamples));
				}
				System.exit(0);
			}

			if (debugMode)
			{
				System.out.println("Initial Decision Tree:");
				dt.Report(dt.getRoot());
				System.out.println("Initial Decision Tree Accuracy on Tune Set: " + Accuracy(dt, allTuneExamples));
				System.out.println("Initial Decision Tree Accuracy on Test Set: " + Accuracy(dt, allTestExamples));
			}

			pruneTree(dt, allTuneExamples, allTestExamples);

			if (modeFlag == 3 || debugMode)
			{
				if (debugMode)
				{
					System.out.println("The tree that gives the best accuracy:");
				}
				dt.Report(dt.getRoot());

				if (debugMode)
				{
					System.out.println();
					System.out.println("Accuracy on Tune Set: " + Accuracy(dt, allTuneExamples));
					System.out.println("Accuracy on Test Set: " + Accuracy(dt, allTestExamples));
				}

				if (modeFlag == 3)
				{
					System.exit(0);
				}
			}

			if (modeFlag == 4)
			{
				if (debugMode)
				{
					System.out.println("---------------------------------------");
					System.out.println("Predictions:");
				}

				ArrayList<String> predictions = Predictions(dt, allTestExamples, classLabels);
				PrintArrayList(predictions);

				if (debugMode)
				{
					System.out.println("Accuracy on test Set: " + Accuracy(dt, allTestExamples));
				}
				System.exit(0);
			}
		}
	}

	/**
	 * Greedy tune-set pruning of dt, performed in place.
	 *
	 * Each round tentatively prunes every internal node, measures tune-set accuracy, and
	 * re-attaches the removed edges. It remembers the node whose pruning gives the highest
	 * accuracy. A later node wins ties, and pruning wins a tie against the current tree.
	 * Rounds repeat while the best accuracy strictly improves; the remembered node is pruned at
	 * the start of the next round. A final tie-only improvement is applied once after the loop.
	 *
	 * @param dt           tree to prune
	 * @param tuneExamples examples whose accuracy drives the pruning decisions
	 * @param testExamples examples used only for debug accuracy output
	 */
	private void pruneTree(DecisionTree dt, ArrayList<Example> tuneExamples, ArrayList<Example> testExamples)
	{
		double tuneAccuracy_old;
		double maxTuneAccuracy;
		Node nodeToPrune = null;

		maxTuneAccuracy = tuneAccuracy_old = Accuracy(dt, tuneExamples);

		do {
			if (nodeToPrune != null)
			{
				nodeToPrune.prune(new ArrayList<Node>(), new ArrayList<String>());
				// Already applied: forget it so the post-loop step cannot prune it a second time.
				nodeToPrune = null;
			}

			tuneAccuracy_old = maxTuneAccuracy;

			ArrayList<Node> allInternalNodes = dt.getAllInternalNodes();

			for (Node node : allInternalNodes)
			{
				// prune() is handed these lists and the entries are re-attached below with addEdge,
				// so presumably it moves the node's children/labels into them -- TODO confirm in Node.
				ArrayList<Node> edges = new ArrayList<Node>();
				ArrayList<String> edgeLabels = new ArrayList<String>();

				if (debugMode)
				{
					System.out.println("--------------------------------------------------");
					System.out.println("Node to prune: " + node.getAttribute().getName() + "\n\n");
				}

				node.prune(edges, edgeLabels);

				double tuneAccuracy_new = Accuracy(dt, tuneExamples);

				if (debugMode)
				{
					dt.Report(dt.getRoot());
					System.out.println("ACCURACY on Tune Set: " + tuneAccuracy_new);
					System.out.println("ACCURACY on Test Set: " + Accuracy(dt, testExamples));
				}

				// ">=" prefers the smaller tree when accuracy ties.
				if (tuneAccuracy_new >= maxTuneAccuracy)
				{
					nodeToPrune = node;
					maxTuneAccuracy = tuneAccuracy_new;
				}

				// Undo the tentative prune.
				for (int i = 0; i < edges.size(); i++)
				{
					node.addEdge(edges.get(i), edgeLabels.get(i));
				}
			}
		} while (maxTuneAccuracy > tuneAccuracy_old);

		// Non-null only if the final round found a tie-level candidate that has not been applied yet.
		if (nodeToPrune != null)
		{
			if (debugMode)
			{
				System.out.println("-----------------------------------------------------------------------------------------");
				System.out.println("The Node that gives best accuracy after pruning: " + nodeToPrune.getAttribute().getName());
			}
			nodeToPrune.prune(new ArrayList<Node>(), new ArrayList<String>());
		}
	}

	/**
	 * Prints each string of the list on its own line.
	 *
	 * @param list lines to print
	 */
	private void PrintArrayList(ArrayList<String> list)
	{
		for (String line : list)
		{
			System.out.println(line);
		}
	}

	/**
	 * Classifies every example of the data set with the tree.
	 *
	 * @param dt          decision tree
	 * @param DataSet     examples to classify
	 * @param classLabels classLabels[0] is emitted for a false prediction, classLabels[1] for true
	 * @return one predicted label per example, in data-set order
	 */
	public ArrayList<String> Predictions(DecisionTree dt, ArrayList<Example> DataSet, String[] classLabels)
	{
		ArrayList<String> predictions = new ArrayList<String>();

		for (Example example : DataSet)
		{
			boolean prediction = dt.determineClass(example.getAttributes(), example.getValues());
			predictions.add(prediction ? classLabels[1] : classLabels[0]);
		}

		return predictions;
	}

	/**
	 * Fraction of examples whose predicted class equals their actual class.
	 *
	 * For an empty data set the result is 0.0 / 0 = NaN. NaN never satisfies ">" or ">=", so the
	 * pruning loop never prunes when the tune set is empty.
	 *
	 * @param dt      decision tree
	 * @param DataSet labelled examples
	 * @return accuracy in [0, 1], or NaN when DataSet is empty
	 */
	public double Accuracy(DecisionTree dt, ArrayList<Example> DataSet)
	{
		int noOfCorrectPredictions = 0;

		for (Example example : DataSet)
		{
			boolean prediction = dt.determineClass(example.getAttributes(), example.getValues());
			if (prediction == example.classValue())
			{
				noOfCorrectPredictions++;
			}
		}

		return ((double) noOfCorrectPredictions) / ((double) DataSet.size());
	}

	/**
	 * Recursively grows the subtree below node.
	 *
	 * Growth stops when the node has no examples, when they all share one class, or when the
	 * node's depth reaches noOfAttributes - 1. Otherwise, for each value of the node's attribute,
	 * a child is created from the matching examples. A child is created only when there is at
	 * least one matching example and those examples are not all of a single class.
	 *
	 * @param node node whose attribute has already been chosen
	 */
	private void BuildDecisionTree(Node node)
	{
		if (node.getExamples().size() == 0
				|| node.SingleClassification()
				|| node.getDepth() >= this.noOfAttributes - 1)
		{
			return;
		}

		for (String attributeValue : node.getAttribute().getDomain())
		{
			ArrayList<Example> selectedExamples = new ArrayList<Example>();

			for (Example example : node.getExamples())
			{
				if (example.getAttributeValue(node.getAttribute()).equals(attributeValue))
				{
					selectedExamples.add(example);
				}
			}

			if ((selectedExamples.size() > 0) && (!Node.SingleClassification(selectedExamples)))
			{
				Node nextNode = createNode(node.getRemainingAttributes(), selectedExamples);
				node.addEdge(nextNode, attributeValue);
				BuildDecisionTree(nextNode);
			}
		}
	}

	/**
	 * Builds a node that splits on the candidate attribute with the highest mutual information
	 * over the given examples. Every other candidate becomes the node's remaining attributes.
	 *
	 * @param candidateAttributes attributes the node may split on
	 * @param examples            examples reaching the node
	 * @return the new (childless) node
	 */
	private Node createNode(ArrayList<Attribute> candidateAttributes, ArrayList<Example> examples)
	{
		Attribute mostImportantAttribute = maxMIAttribute(candidateAttributes, examples);

		ArrayList<Attribute> remainingAttributes = new ArrayList<Attribute>();
		for (Attribute a : candidateAttributes)
		{
			if (!a.equals(mostImportantAttribute))
			{
				remainingAttributes.add(a);
			}
		}

		return new Node(mostImportantAttribute, examples, remainingAttributes);
	}

	/**
	 * Returns the attribute with the strictly highest mutual information. The first attribute
	 * wins ties. Returns null when the list is empty or no attribute's value exceeds -1
	 * (e.g. all are NaN).
	 *
	 * @param selectedAttributes candidate attributes
	 * @param selectedExamples   examples to measure on
	 * @return best attribute, or null
	 */
	private Attribute maxMIAttribute(ArrayList<Attribute> selectedAttributes, ArrayList<Example> selectedExamples)
	{
		double maxMutualInfo = -1;
		Attribute theAttribute = null;

		for (Attribute attribute : selectedAttributes)
		{
			double mutualInfo = MutualInformation(attribute, selectedExamples);
			if (mutualInfo > maxMutualInfo)
			{
				theAttribute = attribute;
				maxMutualInfo = mutualInfo;
			}
		}

		return theAttribute;
	}

	/**
	 * Mutual information I(Y;X) = H(Y) - H(Y|X) between the class Y and the attribute X.
	 *
	 * @param attribute attribute X
	 * @param examples  examples to measure on
	 * @return the mutual information
	 */
	public double MutualInformation(Attribute attribute, ArrayList<Example> examples)
	{
		return Entropy(examples) - Entropy(attribute, examples);
	}

	/**
	 * Class entropy H(Y), summed over the true and false classes via MathUtil.MinusXLogX
	 * (presumably -x*log(x); the log base is not visible here).
	 *
	 * @param allExamples examples to measure on
	 * @return H(Y)
	 */
	public double Entropy(ArrayList<Example> allExamples)
	{
		return (MathUtil.MinusXLogX(Prob(true, allExamples))) + (MathUtil.MinusXLogX(Prob(false, allExamples)));
	}

	/**
	 * Conditional probability P(Y = classValue | X = attributeValue).
	 *
	 * @param classValue     class value y
	 * @param attribute      attribute X
	 * @param attributeValue value x
	 * @param allExamples    examples to measure on
	 * @return the probability, or 0 when no example has X = x
	 */
	public double Prob(boolean classValue, Attribute attribute, String attributeValue, ArrayList<Example> allExamples)
	{
		int matchingCount = 0; // #examples with attribute == attributeValue
		int classCount = 0;    // ... of which class == classValue

		for (Example example : allExamples)
		{
			if (example.getAttributeValue(attribute).equals(attributeValue))
			{
				matchingCount++;
				if (example.classValue() == classValue)
				{
					classCount++;
				}
			}
		}

		if (matchingCount == 0)
		{
			return 0;
		}

		return ((double) classCount) / matchingCount;
	}

	/**
	 * Conditional entropy H(Y|X) = sum over x of P(X = x) * H(Y | X = x), over the attribute's domain.
	 *
	 * @param attribute   attribute X
	 * @param allExamples examples to measure on
	 * @return H(Y|X)
	 */
	public double Entropy(Attribute attribute, ArrayList<Example> allExamples)
	{
		double entropy = 0;
		for (String attributeValue : attribute.getDomain())
		{
			entropy += (Prob(attribute, attributeValue, allExamples) * Entropy(attribute, attributeValue, allExamples));
		}

		return entropy;
	}

	/**
	 * Specific conditional entropy H(Y | X = x).
	 *
	 * @param attribute      attribute X
	 * @param attributeValue value x
	 * @param allExamples    examples to measure on
	 * @return H(Y | X = x)
	 */
	public double Entropy(Attribute attribute, String attributeValue, ArrayList<Example> allExamples)
	{
		double probClassPositive = Prob(true, attribute, attributeValue, allExamples);  // P(Y = t | X = x)
		double probClassNegative = Prob(false, attribute, attributeValue, allExamples); // P(Y = f | X = x)

		return MathUtil.MinusXLogX(probClassNegative) + MathUtil.MinusXLogX(probClassPositive);
	}

	/**
	 * Marginal probability P(X = attributeValue).
	 *
	 * @param attribute      attribute X
	 * @param attributeValue value x
	 * @param allExamples    examples to measure on
	 * @return the probability; NaN when allExamples is empty
	 */
	public double Prob(Attribute attribute, String attributeValue, ArrayList<Example> allExamples)
	{
		double count = 0;
		for (Example example : allExamples)
		{
			if (example.getAttributeValue(attribute).equals(attributeValue))
			{
				count++;
			}
		}

		return count / allExamples.size();
	}

	/**
	 * Class prior P(Y = classValue).
	 *
	 * @param classValue  class value y
	 * @param allExamples examples to measure on
	 * @return the probability; NaN when allExamples is empty
	 */
	public double Prob(boolean classValue, ArrayList<Example> allExamples)
	{
		double count = 0;
		for (Example example : allExamples)
		{
			if (example.classValue() == classValue)
			{
				count++;
			}
		}

		return count / allExamples.size();
	}

	/**
	 * Debug dump of the class labels, every attribute and every example.
	 *
	 * @param allAttributes attributes to print
	 * @param allExamples   examples to print
	 * @param classLabels   classLabels[0] is the negative label, classLabels[1] the positive one
	 */
	public void reportAll(ArrayList<Attribute> allAttributes, ArrayList<Example> allExamples, String[] classLabels)
	{
		System.out.println("Negative Class Label: " + classLabels[0]);
		System.out.println("Positive Class Label: " + classLabels[1]);
		System.out.println();

		System.out.println("Attributes:");

		for (Attribute attribute : allAttributes)
		{
			System.out.println(attribute.toString());
		}

		System.out.println();
		System.out.println("Examples");

		for (Example example : allExamples)
		{
			System.out.println(example.toString());
		}
	}

}
